
module cenf

using JLD, JuMP, MAT, Roots
using KNITRO
using Gadfly
using MathOptInterface
using CSV
using SparseArrays, LinearAlgebra, Statistics
using Logging
using Dates
using DelimitedFiles
using SpecialFunctions
using DataFrames

# Module-level state shared by the functions below. readData() assigns all of
# these names; the estimation routines read lhs, δ, z_gm, z_f and isocodes.
global isocodes, indices, lhs, z_gm, z_f, δ, δt, countrynames, countrynames_full, sectornames, consumptionShares

# Number of η coefficients; eval_problem! uses it to size η and to count the
# cross-country parameters.
global const CONST_DIM_ETA = 2

# NOTE(review): solveCountry, solveWithStartingβη_gm and solveWithStartingβη_f
# are not defined in the part of the file reviewed here -- confirm they exist
# further down.
export solveCountry, readData, solveWithStartingβη_gm, solveWithStartingβη_f, eval_ces_problem!

# Data arrays filled in by readData().
export z_gm, z_f, δ, δt

"""
    readData(root::String)

Read the data set (2016 version) below `root` and store it in the module-level
globals `isocodes`, `indices`, `lhs`, `z_gm`, `z_f`, `δ`, `δt`, `countrynames`,
`countrynames_full`, `sectornames` and `consumptionShares`.

Files read (all relative to `root`):
- `data/dataset.csv`, columns
  `countrycode upstream downstream fraction delta delta_withtime n_div_gm n_div_f`
- `data/countrynames.mat`, `data/countrynames_full.mat`, `data/sectornames.mat`
- `data/consumptionshares.csv` (no header), columns `countrycode sector share`

All three-dimensional arrays are indexed `[country, downstream, upstream]`, so
that the downstream sector is the row index and the upstream sector the column
index (notation X_ni/X_n of the paper). The array sizes are fixed at 109
countries and 35 sectors. Returns `nothing`.
"""
function readData(root::String)

  global isocodes, indices, lhs, z_gm, z_f, δ, δt, countrynames, countrynames_full, sectornames, consumptionShares

  # CSV.File + DataFrame works on both old and current CSV.jl releases, whereas
  # CSV.read without a sink argument is rejected by newer releases.
  data = DataFrame(CSV.File("$(root)/data/dataset.csv"))

  # Initialize
  isocodes=zeros(Int64,size(data,1))
  indices=zeros(Int64,999)
  lhs=zeros(Float64,109,35,35)
  z_gm=zeros(Float64,109,35,35)
  z_f=zeros(Float64,109,35,35)
  δ=zeros(Float64,109,35,35)
  δt=zeros(Float64,109,35,35)

  downstream = data[:,3]
  upstream = data[:,2]

  oldisocode = 0
  index = 0

  for i=1:size(data,1)
    code = data[i,1]
    # rows are grouped by country: a change of code starts the next country
    if oldisocode!=code
      index+=1
    end
    # ISO code as a function of the index
    isocodes[index]=code
    # index as a function of the ISO code
    indices[round(Int64,code)]=index

    n = downstream[i]
    u = upstream[i]
    lhs[index,n,u]=data[i,4]
    # truncate δ at 1
    δ[index,n,u] = data[i,5] < 1.0 ? data[i,5] : 1.0
    δt[index,n,u]=data[i,6]
    z_gm[index,n,u]=data[i,7]
    z_f[index,n,u]=data[i,8]

    oldisocode=code
  end

  # NOTE: z_gm and z_f are used as read; a per-country normalisation by the
  # standard deviation (plus a small offset at z=0) was tried and is disabled.

  # ----------------------------------------------
  # READ COUNTRY NAMES ETC
  # ----------------------------------------------
  countrynames = matread("$(root)/data/countrynames.mat")["countrynames"]
  countrynames_full = matread("$(root)/data/countrynames_full.mat")["countrynames_full"]
  sectornames = matread("$(root)/data/sectornames.mat")["sectornames"]

  # ----------------------------------------------
  # READ CONSUMPTION SHARE DATA
  # ----------------------------------------------
  cshares = DataFrame(CSV.File("$(root)/data/consumptionshares.csv", header=false))
  iso=cshares[:,1]
  sector=cshares[:,2]
  consumptionShares = zeros(Float64,109,35)
  for i=1:size(cshares, 1)
    consumptionShares[indices[iso[i]],sector[i]]=cshares[i,3]
  end

  return nothing
end

"""
    eval_ces_problem!(fixedParameters, startingParameters, outputParameters, settings)

Set up and solve the CES version of the estimation problem with KNITRO and
store the results in `outputParameters`. Returns the objective value.

Every parameter `ρ`, `θ`, `β`, `η`, `logT`, `logS`, `γ` must appear either in
`fixedParameters` (held constant) or in `startingParameters` (estimated,
starting from the given value); otherwise an `UndefVarError` is thrown.
An estimated `ρ` is constrained to be non-negative, and each row of an
estimated `γ` is constrained to be non-negative and to sum to one.

Recognised `settings` keys:
- `"ftol"`, `"xtol"`: solver tolerances (default 1e-8)
- `"omegameasure"`: 1 uses `z_gm`, 2 uses `z_f` (default `z_gm`)
- `"maxiter"`: iteration limit (default 500)
- `"knitro"`: must be 1.0; any other choice raises an error
- `"betaconstraint"`: 1.0 adds the constraints on `β` that keep
  `1 - β[1]*δ - β[2]*δ^2` non-negative at the largest `δ`
- `"savefitted"` (== 1) together with `"outfile"`: additionally write the
  fitted values to that file

On return `outputParameters` holds the numeric value of every parameter,
`"fval"` and `"pseudor2"`. The fitted values are always written (best effort)
to a time-stamped file in `out/`, one row per
`countrycode downstream upstream lhs fitted`.

The function reads the module globals `lhs`, `δ`, `z_gm`, `z_f`, `isocodes`
(see `readData`) and assumes 109 countries and 35 sectors.
"""
function eval_ces_problem!(fixedParameters::Dict{String,Array}, startingParameters::Dict{String,Array}, outputParameters::Dict{Any,Any}, settings::Dict{String,Any})

  # default output file, named after the current time stamp
  outfile_csv = string("out/", string(Dates.now()), ".csv")

  CONST_DEFAULT_FTOL = 1e-8
  CONST_DEFAULT_XTOL = 1e-8

  # a key counts as specified when it is present and not the empty array
  specified(d, key) = get(d, key, []) != []
  # estimated parameters are JuMP variables; anything else is a fixed value
  isestimated(x) = x isa JuMP.VariableRef || x isa AbstractArray{JuMP.VariableRef}

  ftol = specified(settings, "ftol") ? settings["ftol"] : CONST_DEFAULT_FTOL
  xtol = specified(settings, "xtol") ? settings["xtol"] : CONST_DEFAULT_XTOL

  # omegameasure: 1=Z_GM, 2=Z_F
  omegameasure = get(settings, "omegameasure", [])
  if omegameasure == 1
    z = z_gm
  elseif omegameasure == 2
    z = z_f
  else
    # DEFAULT: z_gm
    println("Defaulting to using z_gm")
    z = z_gm
  end

  # number of iterations
  if specified(settings, "maxiter")
    maxiter = settings["maxiter"]
  else
    println("Defaulting to maxiter=500.")
    maxiter = 500
  end

  # whether we should use KNITRO or IPOPT
  if !specified(settings, "knitro")
    println("Defaulting to IPOPT.")
    use_knitro = false
  elseif settings["knitro"] == 1.0
    println("Using KNITRO.")
    use_knitro = true
  else
    println("Using IPOPT.")
    use_knitro = false
  end
  use_knitro || error("ipopt not implemented anymore")

  # same construction and option names as eval_problem!
  m = JuMP.Model(with_optimizer(KNITRO.Optimizer,
              maxit=maxiter,
              ftol=ftol,
              xtol=xtol,
              opttol=1e-8,
              algorithm=1,
              hessopt=1, # how to get hessians: default is exact (1)
              blasoption=0,
              bar_maxcrossit=2,
              linsolver=6, # which linear solver. 6=Intel MKL Pardiso
              linsolver_ooc=1, # if LINSOLVER=6, use out of core option? 1=maybe
              bar_murule=2,
              linsolver_pivottol=1e-08,
              honorbnds=0,
              outlev=3,
              outmode=2,
              outappend=1))

  if specified(fixedParameters, "ρ")
    ρ = fixedParameters["ρ"][1]
  elseif specified(startingParameters, "ρ")
    startρ = startingParameters["ρ"]
    @variable(m, ρ, start = startρ[1])
    # impose a constraint that ρ>=0
    @constraint(m, ρ >= 0)
  else
    # parameter not specified!
    throw(UndefVarError(:ρ))
  end

  if specified(fixedParameters, "θ")
    θ = fixedParameters["θ"][1]
  elseif specified(startingParameters, "θ")
    startθ = startingParameters["θ"]
    @variable(m, θ, start = startθ[1])
  else
    throw(UndefVarError(:θ))
  end

  if specified(fixedParameters, "β")
    β = fixedParameters["β"]
  elseif specified(startingParameters, "β")
    startβ = startingParameters["β"]
    @variable(m, β[i=1:2], start = startβ[i])
  else
    throw(UndefVarError(:β))
  end

  if specified(fixedParameters, "η")
    η = fixedParameters["η"]
  elseif specified(startingParameters, "η")
    startη = startingParameters["η"]
    @variable(m, η[i=1:2], start = startη[i])
  else
    throw(UndefVarError(:η))
  end

  if specified(fixedParameters, "logT")
    logT = fixedParameters["logT"]
  elseif specified(startingParameters, "logT")
    startlogT = startingParameters["logT"]
    @variable(m, logT[c=1:109,i=1:35], start = startlogT[c,i])
  else
    throw(UndefVarError(:logT))
  end

  if specified(fixedParameters, "logS")
    logS = fixedParameters["logS"]
  elseif specified(startingParameters, "logS")
    startlogS = startingParameters["logS"]
    @variable(m, logS[c=1:109,n=1:35], start = startlogS[c,n])
  else
    throw(UndefVarError(:logS))
  end

  if specified(fixedParameters, "γ")
    γ = fixedParameters["γ"]
  elseif specified(startingParameters, "γ")
    startγ = startingParameters["γ"]
    @variable(m, γ[n=1:35,i=1:35] >= 0.0, start = startγ[n,i])
    for n=1:35
      @constraint(m, sum(γ[n,:])==1)
      @constraint(m, γ[n,:] .>= 0)
    end
  else
    throw(UndefVarError(:γ))
  end

  # constraint that we dont get NaNs (default: no constraint)
  if specified(settings, "betaconstraint") && settings["betaconstraint"] == 1.0
    println("Registering beta constraint")
    maxDelta=maximum(δ)
    @constraint(m,β[1]*maxDelta+β[2]*maxDelta*maxDelta<=1.0)
    @constraint(m,β[1]>=0.0)
  end

  @NLexpression(m, expXi[c=1:109,n=1:35,i=1:35], exp(-logS[c,n]+logT[c,i]-θ*log(ifelse(1/(1-β[1]*δ[c,n,i]-β[2]*δ[c,n,i]*δ[c,n,i])<1+η[1]*z[c,n,i]+η[2]*z[c,n,i]*z[c,n,i],1/(1-β[1]*δ[c,n,i]-β[2]*δ[c,n,i]*δ[c,n,i]),1+η[1]*z[c,n,i]+η[2]*z[c,n,i]*z[c,n,i]))))
  @NLexpression(m, pn[c=1:109,n=1:35], sum(γ[n,j]*(1+expXi[c,n,j])^((ρ-1)/θ) for j=1:35))
  @NLexpression(m, g[c=1:109,n=1:35,i=1:35], (γ[n,i]*(1+expXi[c,n,i])^((ρ-1)/θ) / pn[c,n] )/(1+1/expXi[c,n,i]))

  @NLobjective(m, Max, sum(lhs[c,n,i]*log(g[c,n,i])-g[c,n,i] for n=1:35 for i=1:35 for c=1:109))

  @time optimize!(m)

  fval = objective_value(m)

  # store numeric values for ALL parameters (including ρ and θ), whether they
  # were estimated or held fixed
  for (name, obj) in (("ρ", ρ), ("θ", θ), ("β", β), ("η", η), ("logT", logT), ("logS", logS), ("γ", γ))
    outputParameters[name] = isestimated(obj) ? value.(obj) : obj
  end
  outputParameters["fval"] = fval

  # from here on only numeric values are used, so the code below works no
  # matter which parameters were estimated
  ρhat = outputParameters["ρ"]
  θhat = outputParameters["θ"]
  βhat = outputParameters["β"]
  ηhat = outputParameters["η"]
  logThat = outputParameters["logT"]
  logShat = outputParameters["logS"]
  γhat = outputParameters["γ"]

  # get fitted values
  fitted = zeros(Float64,109,35,35)
  XiArray = zeros(Float64,109,35,35)
  for c=1:109
    for n=1:35
      for i=1:35
        dβ = 1/(1-βhat[1]*δ[c,n,i]-βhat[2]*δ[c,n,i]*δ[c,n,i])
        dη = 1+ηhat[1]*z[c,n,i]+ηhat[2]*z[c,n,i]*z[c,n,i]
        # the smaller of the two terms enters, as in expXi above
        XiArray[c,n,i] = -logShat[c,n]+logThat[c,i]-θhat*log(dβ < dη ? dβ : dη)
      end
    end
  end
  for c=1:109
    for n=1:35
      # the denominator does not depend on i: compute it once per (c,n)
      denom = sum(γhat[n,j]*(1+exp(XiArray[c,n,j]))^((ρhat-1)/θhat) for j=1:35)
      for i=1:35
        fitted[c,n,i]=(γhat[n,i]*(1+exp(XiArray[c,n,i]))^((ρhat-1)/θhat) / denom)/(1+exp(-XiArray[c,n,i]))
      end
    end
  end
  outputParameters["pseudor2"] = cor(vec(fitted),vec(lhs))^2

  # save csv output file
  # format is countrycode downstream upstream lhs fitted
  outarray = zeros(Float64,109*35*35,5)
  idx=1
  for c=1:109
    for n=1:35
      for i=1:35
        outarray[idx,:]=[isocodes[c] n i lhs[c,n,i] fitted[c,n,i] ]
        idx=idx+1
      end
    end
  end
  # saving is best effort: a failed write must not discard the estimates
  function trysave(path)
    try
      writedlm(path, outarray)
    catch err
      @warn "Could not save iteration output. Continuing..." path exception=err
    end
  end
  trysave(outfile_csv)
  # honour the output file requested by the caller (see solveCESWithρθ)
  if (get(settings,"savefitted",[]) == 1) && specified(settings, "outfile")
    trysave(settings["outfile"])
  end

  @info "Exiting eval_problem" ρ=ρhat θ=θhat β=βhat η=ηhat fval
  println("Exiting eval_problem with ρ=",ρhat,", θ=",θhat,", β=",βhat,", η=", ηhat, ", fval=", fval)
  return fval
end

"""
    eval_problem!(fixedParameters, startingParameters, outputParameters, settings)

Set up and solve the baseline estimation problem with KNITRO and store the
results in `outputParameters`. Returns the objective value.

Every parameter `θ`, `β`, `η`, `logT`, `logS`, `γ` must appear either in
`fixedParameters` (held constant) or in `startingParameters` (estimated,
starting from the given value); otherwise an `UndefVarError` is thrown.
Each row of an estimated `γ` is constrained to be non-negative and to sum to
one.

Recognised `settings` keys:
- `"ftol"`, `"xtol"`: solver tolerances (default 1e-8)
- `"omegameasure"`: 1 uses `z_gm`, 2 uses `z_f` (default `z_gm`)
- `"betaconstraint"`: 1.0 adds the constraints on `β`
- `"savefitted"` (== 1) together with `"outfile"`: write the fitted values,
  one row per `countrycode downstream upstream lhs fitted`

On return `outputParameters` holds the numeric value of every parameter,
`"fval"`, `"pseudor2"`, and the variance entries `"Varθ"` (when θ is
estimated) or `"Var_β1"`, `"Var_β2"` (otherwise), plus one `"Var_η<i>"` per η
coefficient. A variance is reported as 0.0 when the linear solve behind it
failed.

The post-estimation step assumes that η, `logT`, `logS`, `γ` and exactly one of
θ / β are estimated, and that JuMP numbers the variables in creation order.
The function reads the module globals `lhs`, `δ`, `z_gm`, `z_f`, `isocodes`
(see `readData`) and assumes 109 countries and 35 sectors.
"""
function eval_problem!(fixedParameters::Dict{String,Array}, startingParameters::Dict{String,Array}, outputParameters::Dict{Any,Any}, settings::Dict{String,Any})

  CONST_DEFAULT_FTOL = 1e-8
  CONST_DEFAULT_XTOL = 1e-8

  # a key counts as specified when it is present and not the empty array
  specified(d, key) = get(d, key, []) != []
  # estimated parameters are JuMP variables; anything else is a fixed value
  isestimated(x) = x isa JuMP.VariableRef || x isa AbstractArray{JuMP.VariableRef}
  # position of a variable in the solution vector
  varindex(v) = JuMP.index(v).value

  ftol = specified(settings, "ftol") ? settings["ftol"] : CONST_DEFAULT_FTOL
  xtol = specified(settings, "xtol") ? settings["xtol"] : CONST_DEFAULT_XTOL

  # omegameasure: 1=Z_GM, 2=Z_F
  omegameasure = get(settings, "omegameasure", [])
  if omegameasure == 1
    z = z_gm
  elseif omegameasure == 2
    z = z_f
  else
    # DEFAULT: z_gm
    println("Defaulting to using z_gm")
    z = z_gm
  end

  savefitted = (get(settings,"savefitted",[]) == 1) && specified(settings, "outfile")

  m = JuMP.Model(with_optimizer(KNITRO.Optimizer,
              maxit=500,
              opttol=1e-8,
              ftol=ftol,
              xtol=xtol,
              hessopt=1, # how to get hessians: default is exact (1)
              blasoption=0,
              bar_maxcrossit=2,
              linsolver=6,
              linsolver_ooc=1, # if LINSOLVER=6, use out of core option? 1=maybe
              bar_murule=2,
              linsolver_pivottol=1e-08,
              honorbnds=0,
              outlev=3,
              outmode=0,    # 0 = screen only, 2 = screen and knitro.log
              outappend=1,
              newpoint=0,   # 0 = off, 2= save info about new point to knitro_newpoint.log
              ms_enable=0))

  if specified(fixedParameters, "θ")
    θ = fixedParameters["θ"][1]
  elseif specified(startingParameters, "θ")
    startθ = startingParameters["θ"]
    @variable(m, θ, start = startθ[1])
  else
    # parameter not specified!
    throw(UndefVarError(:θ))
  end

  if specified(fixedParameters, "β")
    β = fixedParameters["β"]
  elseif specified(startingParameters, "β")
    startβ = startingParameters["β"]
    @variable(m, β[i=1:2], start = startβ[i])
  else
    throw(UndefVarError(:β))
  end

  if specified(fixedParameters, "η")
    η = fixedParameters["η"]
  elseif specified(startingParameters, "η")
    startη = startingParameters["η"]
    @variable(m, η[i=1:CONST_DIM_ETA], start = startη[i])
  else
    throw(UndefVarError(:η))
  end

  if specified(fixedParameters, "logT")
    logT = fixedParameters["logT"]
  elseif specified(startingParameters, "logT")
    startlogT = startingParameters["logT"]
    @variable(m, logT[c=1:109,i=1:35], start = startlogT[c,i])
  else
    throw(UndefVarError(:logT))
  end

  if specified(fixedParameters, "logS")
    logS = fixedParameters["logS"]
  elseif specified(startingParameters, "logS")
    startlogS = startingParameters["logS"]
    @variable(m, logS[c=1:109,n=1:35], start = startlogS[c,n])
  else
    throw(UndefVarError(:logS))
  end

  if specified(fixedParameters, "γ")
    γ = fixedParameters["γ"]
  elseif specified(startingParameters, "γ")
    startγ = startingParameters["γ"]
    @variable(m, γ[n=1:35,i=1:35] >= 0.0, start = startγ[n,i])
    for n=1:35
      @constraint(m, sum(γ[n,:])==1)
      @constraint(m, γ[n,:] .>= 0)
    end
  else
    throw(UndefVarError(:γ))
  end

  # constraint that we dont get NaNs (default: no constraint)
  if specified(settings, "betaconstraint") && settings["betaconstraint"] == 1.0
    maxDelta=maximum(δ)
    @constraint(m,β[1]*maxDelta+β[2]*maxDelta*maxDelta<=1)
    @constraint(m,β[1]>=0)
  end

  # CONST_DIM_ETA: the expression below hard-codes two η coefficients
  @NLexpression(m, g[c=1:109,n=1:35,i=1:35], γ[n,i]*1/(1+exp(logS[c,n]-logT[c,i]+θ*log(ifelse(1/(1-β[1]*δ[c,n,i]-β[2]*δ[c,n,i]*δ[c,n,i])<1+η[1]*z[c,n,i]+η[2]*z[c,n,i]*z[c,n,i],1/(1-β[1]*δ[c,n,i]-β[2]*δ[c,n,i]*δ[c,n,i]),1+η[1]*z[c,n,i]+η[2]*z[c,n,i]*z[c,n,i])))))
  @NLobjective(m, Max, sum(lhs[c,n,i]*log(g[c,n,i])-g[c,n,i] for n=1:35 for i=1:35 for c=1:109))

  optimize!(m)

  @show termination_status(m)

  fval = objective_value(m)

  for (name, obj) in (("θ", θ), ("β", β), ("η", η), ("logT", logT), ("logS", logS), ("γ", γ))
    outputParameters[name] = isestimated(obj) ? value.(obj) : obj
  end
  outputParameters["fval"] = fval

  # θ is either estimated (then β is taken as fixed) or fixed (then β is
  # taken as estimated)
  estimate_theta = isestimated(θ)

  # number of cross-country parameters
  ccpar = (estimate_theta ? 1 : 2) + CONST_DIM_ETA
  nvar = ccpar + 109*35*2 + 35*35

  # first put the solution vector into the 'values' object
  values = zeros(Float64, nvar)
  if estimate_theta
    values[varindex(θ)] = outputParameters["θ"]
  else
    values[varindex(β[1])] = outputParameters["β"][1]
    values[varindex(β[2])] = outputParameters["β"][2]
  end
  for i = 1:CONST_DIM_ETA
    values[varindex(η[i])] = outputParameters["η"][i]
  end
  for c=1:109, i=1:35
    values[varindex(logT[c,i])] = outputParameters["logT"][c,i]
    values[varindex(logS[c,i])] = outputParameters["logS"][c,i]
  end
  for n=1:35, i=1:35
    values[varindex(γ[n,i])] = outputParameters["γ"][n,i]
  end

  # create evaluator object
  evaluator = JuMP.NLPEvaluator(m)
  MOI.initialize(evaluator, [:Grad,:HessVec,:Hess])

  # get gradient
  ∇f = zeros(Float64, nvar)
  MOI.eval_objective_gradient(evaluator, ∇f, values)

  # get hessian (objective weight 1.0; there are no nonlinear constraints, so
  # the multiplier vector is irrelevant)
  structure = MOI.hessian_lagrangian_structure(evaluator)
  ∇2fvec = zeros(Float64, length(structure))
  MOI.eval_hessian_lagrangian(evaluator, ∇2fvec, values, 1.0, zeros(Float64, nvar))

  # Assemble the symmetric Hessian from its (row, col, value) triplets. The
  # size is given explicitly so that the matrix is always nvar x nvar, and the
  # diagonal correction stays sparse.
  function dense_hessian(hessian_sparsity, V)
    rows = [i for (i,j) in hessian_sparsity]
    cols = [j for (i,j) in hessian_sparsity]
    raw = sparse(rows, cols, V, nvar, nvar)
    return Symmetric(raw + raw' - spdiagm(0 => Vector(diag(raw))))
  end

  hess = -dense_hessian(structure,∇2fvec)

  # Only diagonal entries of (H\∇f)(H\∇f)' are reported, i.e. the squared
  # entries of H\∇f, so the nvar x nvar outer product is never formed.
  # The vector stays at zero if every solve below fails.
  score = zeros(Float64, nvar)

  # CONST_DIM_ETA: the fallback below looks at η[1] and η[2] only
  if estimate_theta
    ccindex = [varindex(θ), varindex(η[1]), varindex(η[2])]
  else
    ccindex = [varindex(β[1]), varindex(β[2]), varindex(η[1]), varindex(η[2])]
  end

  try
    score = hess\∇f
  catch
    # fall back to the block of cross-country parameters, which is only
    # possible when they occupy the first ccpar positions
    if maximum(ccindex)==ccpar
      try
        score = hess[1:ccpar,1:ccpar]\∇f[1:ccpar]
        println("Success.")
      catch
        @show hess[1:ccpar,1:ccpar]
        show(hess[1:ccpar,1:ccpar])
        @show ∇f[1:ccpar]
        show(∇f[1:ccpar])
      end
    else
      println("No way. Check the ordering of the variables. This shouldnt happen...")
    end
  end

  if estimate_theta
    outputParameters["Varθ"] = score[varindex(θ)]*score[varindex(θ)]
  else
    outputParameters["Var_β1"] = score[varindex(β[1])]*score[varindex(β[1])]
    outputParameters["Var_β2"] = score[varindex(β[2])]*score[varindex(β[2])]
  end
  for i=1:CONST_DIM_ETA
    outputParameters["Var_η$(i)"] = score[varindex(η[i])]*score[varindex(η[i])]
  end

  # get fitted values
  θhat = outputParameters["θ"]
  βhat = outputParameters["β"]
  ηhat = outputParameters["η"]
  fitted = zeros(Float64,109,35,35)
  for c=1:109
    for n=1:35
      for i=1:35
        # CONST_DIM_ETA
        dη = 1.0+ηhat[1]*z[c,n,i]+ηhat[2]*z[c,n,i]*z[c,n,i]
        dβ = 1/(1-βhat[1]*δ[c,n,i]-βhat[2]*δ[c,n,i]*δ[c,n,i])
        # the smaller of the two terms enters, as in g above
        fitted[c,n,i]=outputParameters["γ"][n,i]*1/(1+exp(outputParameters["logS"][c,n]-outputParameters["logT"][c,i]+θhat*log(dβ < dη ? dβ : dη)))
      end
    end
  end
  outputParameters["pseudor2"] = cor(vec(fitted),vec(lhs))^2

  if savefitted
    # format is countrycode downstream upstream lhs fitted
    outarray = zeros(Float64,109*35*35,5)
    idx=1
    for c=1:109
      for n=1:35
        for i=1:35
          outarray[idx,:]=[isocodes[c] n i lhs[c,n,i] fitted[c,n,i] ]
          idx=idx+1
        end
      end
    end
    writedlm(settings["outfile"], outarray)
  end

  @info "Exiting eval_problem" θ=θhat β=βhat η=ηhat fval

  return fval

end

"""
    solveWithFixedTheta(omegameasure, startβ, startη, startγ, startlogT, startlogS[, saveoutfilename])

Run `eval_problem!` with θ held fixed at 4.0, estimating β, η, γ, logT and logS
from the given starting values, with tolerances 1e-9 and the β constraint
switched on. When `saveoutfilename` is given and non-empty, the fitted values
are saved to that file. Returns the dictionary filled in by `eval_problem!`.
"""
function solveWithFixedTheta(omegameasure::Float64, startβ::Vector{Float64}, startη::Vector{Float64}, startγ::Array{Float64,2}, startlogT::Array{Float64,2}, startlogS::Array{Float64,2})
  # the empty file name means "do not save fitted values"
  return solveWithFixedTheta(omegameasure, startβ, startη, startγ, startlogT, startlogS, "")
end
function solveWithFixedTheta(omegameasure::Float64, startβ::Vector{Float64}, startη::Vector{Float64}, startγ::Array{Float64,2}, startlogT::Array{Float64,2}, startlogS::Array{Float64,2}, saveoutfilename::String)
  held = Dict{String,Array}("θ" => [4.0])
  initial = Dict{String,Array}(
    "β" => startβ,
    "η" => startη,
    "γ" => startγ,
    "logT" => startlogT,
    "logS" => startlogS,
  )
  settings = Dict{String,Any}(
    "ftol" => 1e-9,
    "xtol" => 1e-9,
    "omegameasure" => omegameasure,
    "betaconstraint" => 1.0,
  )
  if !isempty(saveoutfilename)
    settings["savefitted"] = 1
    settings["outfile"] = saveoutfilename
  end

  results = Dict()
  eval_problem!(held, initial, results, settings)
  return results
end

"""
    solveCESWithρθ(omegameasure, startβ, startη)
    solveCESWithρθ(omegameasure, startβ, startη, startγ, startlogT, startlogS, saveoutfilename)

Run `eval_ces_problem!` with ρ and θ both held fixed at 4.0, estimating β, η, γ,
logT and logS, using KNITRO with tolerances 1e-8, at most 1000 iterations and
the β constraint switched on. The short form starts γ at 1/35 everywhere and
logT, logS at zero. A non-empty `saveoutfilename` is passed on through the
`"savefitted"` / `"outfile"` settings. Returns the dictionary filled in by
`eval_ces_problem!`.
"""
function solveCESWithρθ(omegameasure::Float64,startβ::Vector{Float64}, startη::Vector{Float64})
  # default starting values; the empty file name adds no output settings
  return solveCESWithρθ(omegameasure, startβ, startη,
                        1/35 .* ones(Float64,35,35),
                        zeros(Float64,109,35),
                        zeros(Float64,109,35),
                        "")
end

function solveCESWithρθ(omegameasure::Float64,startβ::Vector{Float64}, startη::Vector{Float64},startγ::Array{Float64,2}, startlogT::Array{Float64,2}, startlogS::Array{Float64,2}, saveoutfilename::String)
  held = Dict{String,Array}("ρ" => [4.0], "θ" => [4.0])
  initial = Dict{String,Array}(
    "β" => startβ,
    "η" => startη,
    "γ" => startγ,
    "logT" => startlogT,
    "logS" => startlogS,
  )
  settings = Dict{String,Any}(
    "ftol" => 1e-8,
    "xtol" => 1e-8,
    "omegameasure" => omegameasure,
    "betaconstraint" => 1.0,
    "knitro" => 1.0,
    "maxiter" => 1000,
  )
  if !isempty(saveoutfilename)
    settings["savefitted"] = 1
    settings["outfile"] = saveoutfilename
  end

  results = Dict()
  eval_ces_problem!(held, initial, results, settings)
  return results
end

"""
    solveWithθ_gm(startθ, startη)
    solveWithθ_gm(startθ, startη, startγ, startlogT, startlogS, saveoutfilename)

Run `eval_problem!` with β held fixed at (0.5, 0.0) and omegameasure 1.0,
estimating θ, η, γ, logT and logS with tolerances 1e-9 and no β constraint.
The short form starts γ at 1/35 everywhere and logT, logS at zero; the long
form takes starting values for all parameters and, when `saveoutfilename` is
non-empty, saves the fitted values to that file. Returns the dictionary filled
in by `eval_problem!`.
"""
function solveWithθ_gm(startθ::Vector{Float64}, startη::Vector{Float64})
  # default starting values; the empty file name means "do not save"
  return solveWithθ_gm(startθ, startη,
                       1/35 .* ones(Float64,35,35),
                       zeros(Float64,109,35),
                       zeros(Float64,109,35),
                       "")
end
function solveWithθ_gm(startθ::Vector{Float64}, startη::Vector{Float64}, startγ::Array{Float64,2}, startlogT::Array{Float64,2}, startlogS::Array{Float64,2}, saveoutfilename::String)
  # β is stored as a 2x1 matrix
  held = Dict{String,Array}("β" => permutedims([0.5 0.0]))
  initial = Dict{String,Array}(
    "θ" => startθ,
    "η" => startη,
    "γ" => startγ,
    "logT" => startlogT,
    "logS" => startlogS,
  )
  settings = Dict{String,Any}(
    "ftol" => 1e-9,
    "xtol" => 1e-9,
    "omegameasure" => 1.0,
    "betaconstraint" => 0.0,
  )
  if !isempty(saveoutfilename)
    settings["savefitted"] = 1
    settings["outfile"] = saveoutfilename
  end

  results = Dict()
  eval_problem!(held, initial, results, settings)
  return results
end

"""
    solveWithθ_f(startθ, startη)
    solveWithθ_f(startθ, startη, startγ, startlogT, startlogS, saveoutfilename)

Run `eval_problem!` with β held fixed at (0.5, 0.0) and omegameasure 2.0,
estimating θ, η, γ, logT and logS with tolerances 1e-9 and no β constraint.
The short form starts γ at 1/35 everywhere and logT, logS at zero; the long
form takes starting values for all parameters and, when `saveoutfilename` is
non-empty, saves the fitted values to that file. Returns the dictionary filled
in by `eval_problem!`.
"""
function solveWithθ_f(startθ::Vector{Float64}, startη::Vector{Float64})
  # default starting values; the empty file name means "do not save"
  return solveWithθ_f(startθ, startη,
                      1/35 .* ones(Float64,35,35),
                      zeros(Float64,109,35),
                      zeros(Float64,109,35),
                      "")
end
function solveWithθ_f(startθ::Vector{Float64}, startη::Vector{Float64}, startγ::Array{Float64,2}, startlogT::Array{Float64,2}, startlogS::Array{Float64,2}, saveoutfilename::String)
  # β is stored as a 2x1 matrix
  held = Dict{String,Array}("β" => permutedims([0.5 0.0]))
  initial = Dict{String,Array}(
    "θ" => startθ,
    "η" => startη,
    "γ" => startγ,
    "logT" => startlogT,
    "logS" => startlogS,
  )
  settings = Dict{String,Any}(
    "ftol" => 1e-9,
    "xtol" => 1e-9,
    "omegameasure" => 2.0,
    "betaconstraint" => 0.0,
  )
  if !isempty(saveoutfilename)
    settings["savefitted"] = 1
    settings["outfile"] = saveoutfilename
  end

  results = Dict()
  eval_problem!(held, initial, results, settings)
  return results
end

"""
    getDeepParameters!(θ, β, γ, FETi, FESn, dni, T, S)

Get the structural `T` and `S` from the 'fixed effects' and write them into the
preallocated arrays `T` and `S` (both country x sector). Returns `nothing`.

`FETi` and `FESn` must be in levels, not logs (i.e. positive only!). They are
related to the structural parameters by

    FESn = S/μ          ('logS' in the estimation)
    FETi = T*p^(-θ)     ('logT' in the estimation)

For every country `c` the sector price vector is computed as

    p[n] = exp( sum_i γ[n,i] * log( (α/γ[n,i]) *
                (FESn[c,n]*μ^(-θ) + FETi[c,i]*μ^(-θ)*dni[c,n,i]^(-θ))^(-1/θ) ) )

and then `S[c,:] = FESn[c,:]*μ^(-θ)` and `T[c,i] = FETi[c,i]*p[i]^θ`.
Entries with `γ[n,i] == 0` contribute nothing to the sum; evaluating them
literally would give `0*Inf = NaN`. α and μ are both fixed at 1.0 (the version
with market power is not active). `β` is not used; it is kept so that existing
callers keep working.

The number of countries is taken from `FESn`; the number of sectors is implied
by the shapes of `γ`, `FETi`, `FESn` and `dni`, which must agree.
"""
function getDeepParameters!(θ::Float64, β::Array{Float64,1}, γ::Array{Float64,2}, FETi::Array{Float64,2}, FESn::Array{Float64,2},dni::Array{Float64,3},T::Array{Float64,2}, S::Array{Float64,2})

  # version with market power in comments
  #σ=3.5
  α=1.0 #gamma((1-σ)./θ + 1).^(1.0 ./(1-σ));
  μ=1.0 #σ./(σ-1)

  markup = μ^(-θ)

  for c in 1:size(FESn, 1)
    FESnc = FESn[c, :]     # column: varies with the row (downstream) index n
    FETic = FETi[c, :]'    # row: varies with the column (upstream) index i
    dnic = dni[c, :, :]

    # calculate price vector
    cost = (FESnc .* markup .+ FETic .* markup .* dnic .^ (-θ)) .^ (-1.0 / θ)
    # zero shares drop out of the sum instead of producing NaN
    logterms = ifelse.(γ .== 0.0, 0.0, γ .* log.((α ./ γ) .* cost))
    pc = vec(exp.(sum(logterms, dims=2)))

    # back out Ti and Sn
    S[c, :] = FESnc .* markup
    T[c, :] = FETi[c, :] .* (pc .^ θ)
  end

  return nothing
end
"""
    getDeepParameters_ces!(ρ, θ, γ, FETi, FESn, dni, T, S)

CES version of `getDeepParameters!`: fill the preallocated 109 x 35 arrays `T`
and `S` from the fixed effects `FETi` and `FESn` (in levels).

For every country `c` the sector price vector is

    p[n] = ( sum_i γ[n,i] * α^(1-ρ) *
             (FESn[c,n]*μ^(-θ) + FETi[c,i]*μ^(-θ)*dni[c,n,i]^(-θ))^(-(1-ρ)/θ) )^(1/(1-ρ))

and then `S[c,:] = FESn[c,:]*μ^(-θ)` and `T[c,i] = FETi[c,i]*p[i]^θ`.
α and μ are both fixed at 1.0 (the version with market power is not active).
"""
function getDeepParameters_ces!(ρ::Float64, θ::Float64, γ::Array{Float64,2}, FETi::Array{Float64,2}, FESn::Array{Float64,2},dni::Array{Float64,3},T::Array{Float64,2}, S::Array{Float64,2})

  # version with market power in comments
  #σ=3.5
  α = 1.0 #gamma((1-σ) ./ θ + 1) .^ (1.0 ./ (1-σ));
  μ = 1.0 #σ./(σ-1)

  nsectors = 35
  for country in 1:109
    fesn_row = FESn[country, :]
    feti_row = FETi[country, :]

    # replicate the fixed effects so that [n,i] holds FESn[c,n] and FETi[c,i]
    Smat = repeat(fesn_row, 1, nsectors)
    Tmat = repeat(feti_row', nsectors, 1)
    d = dni[country, :, :]

    # calculate price vector
    inner = Smat .* μ .^ (-θ) + Tmat .* μ .^ (-θ) .* d .^ (-θ)
    weighted = γ .* α .^ (1 - ρ) .* inner .^ (-(1 - ρ) / θ)
    prices = sum(weighted, dims=2) .^ (1 / (1 - ρ))

    # back out Ti and Sn
    S[country, :] = fesn_row' .* μ .^ (-θ)
    T[country, :] = feti_row .* (prices .^ θ)
  end
end

# version where we pass dI as opposed to \eta and z
#
# For every country the model is solved three times with solveCountry:
#   1. with the country's own distortions dni ("before"),
#   2. with all distortions set to one ("To zero"),
#   3. with the distortions of country index 106 ("To US").
# dni[c,n,i] is the smaller of 1/(1-β[1]δ-β[2]δ²) and dI[c,n,i].
# One diagnostic line per country and the two means are printed to stdout.
# If plotfilename is non-empty, a scatter plot of the "To US" change (in
# percent) against δ[:,1,1] is written to "<plotfilename>.svg".
#
# Returns (changes, ioshares_before, ioshares_after_us), where changes is a
# 109x2 matrix with columns [change "To US", change "To zero"].
function solveCountries(θ::Float64, β::Vector{Float64}, dI::Array{Float64,3}, γ::Array{Float64,2}, FETi::Array{Float64,2}, FESn::Array{Float64,2}, plotfilename::String, plottitle::String)
  global δ, countrynames, countrynames_full, consumptionShares

  NCOUNTRIES = 109
  NSECTORS = 35
  USINDEX = 106   # country whose distortions are used in the "To US" counterfactual

  percentageChange=zeros(Float64,NCOUNTRIES)
  percentageChange_US=zeros(Float64,NCOUNTRIES)

  ioshares_before=zeros(Float64,NCOUNTRIES,NSECTORS,NSECTORS)
  ioshares_after_us=zeros(Float64,NCOUNTRIES,NSECTORS,NSECTORS)

  # create dni
  dni = zeros(Float64,NCOUNTRIES,NSECTORS,NSECTORS)
  for c=1:NCOUNTRIES, n=1:NSECTORS, i=1:NSECTORS
    δcni = δ[c,n,i]
    dδ = 1/(1-β[1]*δcni-β[2]*δcni*δcni)
    dIcni = dI[c,n,i]
    dni[c,n,i] = dδ < dIcni ? dδ : dIcni
  end

  # convert fixed effects into structural parameters
  S=zeros(Float64,NCOUNTRIES,NSECTORS)
  T=zeros(Float64,NCOUNTRIES,NSECTORS)
  getDeepParameters!(θ, β, γ, FETi, FESn ,dni,T, S)

  # the two counterfactual distortion matrices are the same for every country
  dni_none = ones(Float64,NSECTORS,NSECTORS)
  dni_us = dni[USINDEX,:,:]

  for c=1:NCOUNTRIES
    logTc = log.(T[c,:])
    logSc = log.(S[c,:])
    sharesc = consumptionShares[c,:]

    (out_before, X_before, y_before, p_before) = solveCountry(θ, γ, logTc, logSc, dni[c,:,:], sharesc)
    ioshares_before[c,:,:]=X_before

    (out_zero, _, _, _) = solveCountry(θ, γ, logTc, logSc, dni_none, sharesc)
    percentageChange[c]=(out_zero-out_before)/out_before

    (out_us, X_us, y_after_us, p_after_us) = solveCountry(θ, γ, logTc, logSc, dni_us, sharesc)
    percentageChange_US[c]=(out_us-out_before)/out_before
    ioshares_after_us[c,:,:]=X_us

    # note: "Frac dY" is 0/0 = NaN for the benchmark country itself
    println("$(countrynames_full[c]),$(countrynames[c]), delta: $(δ[c,1,1]), To zero: $(percentageChange[c]), To US: $(percentageChange_US[c]), Frac dY: $(((y_after_us-y_before)/y_before)/(percentageChange_US[c])), DeltaP: $((p_after_us-p_before)/p_before)")
  end
  println("Mean: ", mean(percentageChange))
  println("Mean To US: ", mean(percentageChange_US))

  # do plot
  if plotfilename != ""
    df = DataFrame(x = δ[:,1,1], y = 100.0 .* percentageChange_US, name = vec(countrynames))
    p=Gadfly.plot(df, x=:x, y=:y, Geom.point, label =:name, Geom.label,
      Coord.Cartesian(xmin=0.08,xmax=1.08), Guide.xticks(ticks=[0.25, 0.5, 0.75, 1]),
      Guide.xlabel("Enforcement cost"), Guide.ylabel("Counterfactual welfare increase, in percent",orientation=:vertical),
      Guide.title(plottitle))
    draw(SVG("$(plotfilename).svg",15cm, 10cm),p)
    #run(`cairosvg $(plotfilename).svg -o $(plotfilename).pdf`)
  end

  return [percentageChange_US percentageChange], ioshares_before, ioshares_after_us

end

# version where dI is built from η and z: dI = 1 + η[1]*z + η[2]*z²,
# with z = z_f when omegameasure == 2.0 and z = z_gm otherwise.
#
# For every country the model is solved three times with solveCountry:
#   1. with the country's own distortions dni ("before"),
#   2. with all distortions set to one ("To zero"),
#   3. with the distortions of country index 106 ("To US").
# dni[c,n,i] is the smaller of 1/(1-β[1]δ-β[2]δ²) and dI.
# One diagnostic line per country and the two means are printed to stdout.
# If plotfilename is non-empty, a scatter plot of the "To US" change (in
# percent) against δ[:,1,1] is written to "<plotfilename>.svg".
#
# Returns (changes, ioshares_before, ioshares_after_us), where changes is a
# 109x2 matrix with columns [change "To US", change "To zero"].
function solveCountries(θ::Float64, β::Vector{Float64}, η::Vector{Float64}, γ::Array{Float64,2}, omegameasure::Float64, FETi::Array{Float64,2}, FESn::Array{Float64,2}, plotfilename::String, plottitle::String)

  global z_gm, z_f, δ, countrynames, countrynames_full, consumptionShares

  NCOUNTRIES = 109
  NSECTORS = 35
  USINDEX = 106   # country whose distortions are used in the "To US" counterfactual

  percentageChange=zeros(Float64,NCOUNTRIES)
  percentageChange_US=zeros(Float64,NCOUNTRIES)

  ioshares_before=zeros(Float64,NCOUNTRIES,NSECTORS,NSECTORS)
  ioshares_after_us=zeros(Float64,NCOUNTRIES,NSECTORS,NSECTORS)

  # create dni
  # omegameasure 2.0 selects z_f, anything else z_gm
  z = omegameasure==2.0 ? z_f : z_gm

  dni = zeros(Float64,NCOUNTRIES,NSECTORS,NSECTORS)
  for c=1:NCOUNTRIES, n=1:NSECTORS, i=1:NSECTORS
    δcni = δ[c,n,i]
    zcni = z[c,n,i]
    # CONST_DIM_ETA
    dIcni = 1+η[1]*zcni+η[2]*zcni*zcni
    dδ = 1/(1-β[1]*δcni-β[2]*δcni*δcni)
    dni[c,n,i] = dδ < dIcni ? dδ : dIcni
  end

  # convert fixed effects into structural parameters
  S=zeros(Float64,NCOUNTRIES,NSECTORS)
  T=zeros(Float64,NCOUNTRIES,NSECTORS)
  getDeepParameters!(θ, β, γ, FETi, FESn ,dni,T, S)

  # the two counterfactual distortion matrices are the same for every country
  dni_none = ones(Float64,NSECTORS,NSECTORS)
  dni_us = dni[USINDEX,:,:]

  for c=1:NCOUNTRIES
    logTc = log.(T[c,:])
    logSc = log.(S[c,:])
    sharesc = consumptionShares[c,:]

    (out_before, X_before, y_before, p_before) = solveCountry(θ, γ, logTc, logSc, dni[c,:,:], sharesc)
    ioshares_before[c,:,:]=X_before

    (out_zero, _, _, _) = solveCountry(θ, γ, logTc, logSc, dni_none, sharesc)
    percentageChange[c]=(out_zero-out_before)/out_before

    (out_us, X_us, y_after_us, p_after_us) = solveCountry(θ, γ, logTc, logSc, dni_us, sharesc)
    percentageChange_US[c]=(out_us-out_before)/out_before
    ioshares_after_us[c,:,:]=X_us

    # note: "Frac dY" is 0/0 = NaN for the benchmark country itself
    println("$(countrynames_full[c]),$(countrynames[c]), delta: $(δ[c,1,1]), To zero: $(percentageChange[c]), To US: $(percentageChange_US[c]), Frac dY: $(((y_after_us-y_before)/y_before)/(percentageChange_US[c])), DeltaP: $((p_after_us-p_before)/p_before)")
  end
  println("Mean: ", mean(percentageChange))
  println("Mean To US: ", mean(percentageChange_US))

  # do plot
  if plotfilename != ""
    df = DataFrame(x = δ[:,1,1], y = 100.0 .* percentageChange_US, name = vec(countrynames))
    p=Gadfly.plot(df, x=:x, y=:y, Geom.point, label =:name, Geom.label,
      Coord.Cartesian(xmin=0.08,xmax=1.08), Guide.xticks(ticks=[0.25, 0.5, 0.75, 1]),
      Guide.xlabel("Enforcement cost"), Guide.ylabel("Counterfactual welfare increase, in percent",orientation=:vertical),
      Guide.title(plottitle))
    draw(SVG("$(plotfilename).svg",15cm, 10cm),p)
    #run(`cairosvg $(plotfilename).svg -o $(plotfilename).pdf`)
  end

  return [percentageChange_US percentageChange], ioshares_before, ioshares_after_us
end

# CES Version
#
# Same three solves per country as the non-CES method, but using
# getDeepParameters_ces! and solveCountry_ces: own distortions ("before"), all
# distortions equal to one ("To zero"), and the distortions of country index
# 106 ("To US"). dni[c,n,i] is the smaller of 1/(1-β[1]δ-β[2]δ²) and
# 1+η[1]z+η[2]z², with z = z_f when omegameasure == 2.0 and z_gm otherwise.
# Prints one line per country plus the two means and, when plotfilename is
# non-empty, writes "<plotfilename>.svg".
#
# Returns (changes, ioshares_before, ioshares_after_us), where changes is a
# 109x2 matrix with columns [change "To US", change "To zero"].
function solveCountries_ces(ρ::Float64, θ::Float64, β::Vector{Float64}, η::Vector{Float64}, γ::Array{Float64,2}, omegameasure::Float64, FETi::Array{Float64,2}, FESn::Array{Float64,2}, plotfilename::String, plottitle::String)

  global z_gm, z_f, δ, countrynames, countrynames_full, consumptionShares

  ncountries = 109
  nsectors = 35
  benchmark = 106   # index used for the "To US" counterfactual

  percentageChange = zeros(Float64, ncountries)
  percentageChange_US = zeros(Float64, ncountries)
  ioshares_before = zeros(Float64, ncountries, nsectors, nsectors)
  ioshares_after_us = zeros(Float64, ncountries, nsectors, nsectors)

  # create dni
  # omegameasure 2.0 picks z_f, any other value z_gm
  z = omegameasure == 2.0 ? z_f : z_gm

  dni = zeros(Float64, ncountries, nsectors, nsectors)
  for c in 1:ncountries, n in 1:nsectors, i in 1:nsectors
    δcni = δ[c,n,i]
    zcni = z[c,n,i]
    fromδ = 1/(1-β[1]*δcni-β[2]*δcni*δcni)
    fromz = 1+η[1]*zcni+η[2]*zcni*zcni
    dni[c,n,i] = fromδ < fromz ? fromδ : fromz
  end

  # convert fixed effects into structural parameters
  S = zeros(Float64, ncountries, nsectors)
  T = zeros(Float64, ncountries, nsectors)
  getDeepParameters_ces!(ρ, θ, γ, FETi, FESn, dni, T, S)

  for c in 1:ncountries
    logTc = log.(T[c,:])
    logSc = log.(S[c,:])
    sharesc = consumptionShares[c,:]

    # own distortions
    (before, X_before) = solveCountry_ces(ρ, θ, β, η, γ, logTc, logSc, dni[c,:,:], sharesc)
    ioshares_before[c,:,:] = X_before

    # all distortions equal to one
    (tozero, _) = solveCountry_ces(ρ, θ, β, η, γ, logTc, logSc, ones(Float64, nsectors, nsectors), sharesc)
    percentageChange[c] = (tozero-before)/before

    # distortions of the benchmark country
    (tous, X_us) = solveCountry_ces(ρ, θ, β, η, γ, logTc, logSc, dni[benchmark,:,:], sharesc)
    percentageChange_US[c] = (tous-before)/before
    ioshares_after_us[c,:,:] = X_us

    println("$(countrynames_full[c]),$(countrynames[c]), delta: $(δ[c,1,1]), To zero: $(percentageChange[c]), To US: $(percentageChange_US[c])")
  end
  println("Mean: ", mean(percentageChange))
  println("Mean To US: ", mean(percentageChange_US))

  # do plot
  if !isempty(plotfilename)
    ticks = [0.25, 0.5, 0.75, 1]
    frame = DataFrame(x = δ[:,1,1], y = 100.0 .* percentageChange_US, name = vec(countrynames))
    fig = Gadfly.plot(frame, x=:x, y=:y, Geom.point, label =:name, Geom.label,
      Coord.Cartesian(xmin=0.08,xmax=1.08), Guide.xticks(ticks=ticks),
      Guide.xlabel("Enforcement cost"), Guide.ylabel("Counterfactual welfare increase, in percent",orientation=:vertical),
      Guide.title(plottitle))
    draw(SVG("$(plotfilename).svg",15cm, 10cm),fig)
    #run(`cairosvg $(plotfilename).svg -o $(plotfilename).pdf`)
  end

  return hcat(percentageChange_US, percentageChange), ioshares_before, ioshares_after_us
end

# this function solves equilibrium given the parameters of the model. Note that logT, logS etc are
# NOT the c_i and c_n components from the estimation (these include the p_i as well!)
#
# Arguments
#   θ                  exponent used in the cost formula
#   γ                  SDIM x SDIM weights; row = downstream n, column = upstream i
#   logT, logS         vectors of length SDIM (logs of T_i and S_n)
#   dni                SDIM x SDIM distortions
#   consumptionShares  vector of length SDIM
# SDIM is taken from length(logS).
#
# Returns (outputPerCapita, XniXn, 1+πOverL, priceOfHouseholds).
function solveCountry(θ::Float64, γ::Array{Float64,2}, logT::Array{Float64,1}, logS::Array{Float64,1}, dni::Array{Float64,2}, consumptionShares::Array{Float64,1})

  # calibration
  α=1.0 #gamma((1-σ) ./ θ + 1).^(1.0 ./ (1-σ));
  μ=1.0 #σ./(σ-1)

  SDIM = length(logS)

  # set up matrices: repS[n,i] = S_n, repT[n,i] = T_i
  repS = repeat(exp.(logS),1,SDIM)
  repT = repeat(exp.(logT'),SDIM,1)

  # solve for sectoral price levels by fixed-point iteration, starting from
  # p = 1. pvec' is a 1 x SDIM row, so broadcasting puts p_i in column i.
  # NOTE(review): there is no iteration limit; a non-converging map loops forever.
  pvec = ones(Float64,SDIM,1)
  pvec_old = zeros(Float64,SDIM,1)
  while (norm(pvec - pvec_old)>=0.000001)
    pvec_old = pvec

    pvec = exp.(sum(γ .* log.((α./γ) .* (repS .+ repT .* (μ .* pvec' .* dni).^(-θ)).^(-1.0 / θ)),dims=2))
  end

  # now calculate X_ni/X_n
  XniXn = γ ./ (1.0 .+ repS./(repT .* (μ .* pvec' .* dni).^(-θ)))

  # calculate profits
  # both matrices below do not depend on πOverL, so build them once
  leontief = Matrix(1I, SDIM, SDIM) .- ((1.0 ./μ).*XniXn)'
  profitweights = sum((1.0 - 1.0 ./ μ).*XniXn,dims=2)'
  # with μ = 1.0 the profit weights are zero and this reduces to πOverL itself
  function excessProfits(πOverL)
    Xnvec = leontief\(consumptionShares.*(1+πOverL))
    return (πOverL .- profitweights * Xnvec)[1]
  end

  # root of excessProfits, with 0 and 1000 passed to fzero as the end points
  πOverL = fzero(excessProfits,0,1000)

  # check this
  # household price index: exp( Σ_n s_n log(p_n/s_n) ); a zero share gives NaN
  priceOfHouseholds = exp.(consumptionShares'*log.(pvec./consumptionShares))[1]

  # output per capita
  outputPerCapita = (1+πOverL)/priceOfHouseholds

  return outputPerCapita, XniXn, 1+πOverL, priceOfHouseholds
end

# CES version of the above
#
# Arguments
#   ρ                  CES parameter; the price formula uses 1/(1-ρ), so ρ must differ from 1
#   θ                  exponent used in the cost formula
#   β, η               not used by this function (kept in the signature for the callers)
#   γ                  SDIM x SDIM weights; row = downstream n, column = upstream i
#   logT, logS         vectors of length SDIM (logs of T_i and S_n)
#   dni                SDIM x SDIM distortions
#   consumptionShares  vector of length SDIM
# SDIM is taken from length(logS).
#
# Returns (outputPerCapita, XniXn).
function solveCountry_ces(ρ::Float64, θ::Float64, β::Vector{Float64}, η::Vector{Float64}, γ::Array{Float64,2}, logT::Array{Float64,1}, logS::Array{Float64,1}, dni::Array{Float64,2}, consumptionShares::Array{Float64,1})

  #σ=3.5
  α=1.0 #gamma((1-σ)./θ + 1).^(1.0 ./(1-σ));
  μ=1.0 #σ./(σ-1)
  SDIM = length(logS)

  # set up matrices: repS[n,i] = S_n, repT[n,i] = T_i
  repS = repeat(exp.(logS),1,SDIM)
  repT = repeat(exp.(logT'),SDIM,1)

  # solve for sectoral price levels iteratively by Banach iteration, starting
  # from p = 1. pvec' is a 1 x SDIM row, so broadcasting puts p_i in column i.
  # NOTE(review): there is no iteration limit; a non-converging map loops forever.
  pvec = ones(Float64,SDIM,1)
  pvec_old = zeros(Float64,SDIM,1)
  while (norm(pvec - pvec_old)>=0.000001)
    pvec_old = pvec
    pvec = ( sum( γ.*α.^(1-ρ).*( repS .+ repT.*(μ.*pvec'.*dni).^(-θ) ).^(-(1-ρ)/θ) ,dims=2) ).^(1/(1-ρ))
  end

  # now calculate X_ni/X_n
  # Tp[n,i] = T_i (μ p_i d_ni)^(-θ)
  Tp = repT .* (μ .* pvec' .* dni).^(-θ)
  # CES weight of input i in sector n, normalised by its row sum below
  weight = γ .* ( α .* ( repS .+ Tp ).^(-1.0/θ) ).^(1.0-ρ)
  XniXn = weight ./ sum(weight,dims=2) .* 1.0 ./ (1.0 .+ repS./Tp)

  # calculate profits
  # both matrices below do not depend on πOverL, so build them once
  leontief = Matrix(1I, SDIM, SDIM) .- ((1.0 ./μ).*XniXn)'
  profitweights = sum((1.0 - 1.0 ./μ).*XniXn,dims=2)'
  # with μ = 1.0 the profit weights are zero and this reduces to πOverL itself
  function excessProfits(πOverL)
    Xnvec = leontief\(consumptionShares.*(1+πOverL))
    #excessProfits = πOverL - sum((1 - 1.0 ./(μ.*dni)).*XniXn,2)'*Xnvec
    return (πOverL .- profitweights * Xnvec)[1]
  end

  # root of excessProfits, with 0 and 1000 passed to fzero as the end points
  πOverL = fzero(excessProfits,0,1000)

  # household price index: exp( Σ_n s_n log(p_n/s_n) ); a zero share gives NaN
  priceOfHouseholds = exp.(consumptionShares'*log.(pvec./consumptionShares))[1]

  # output per capita
  outputPerCapita = (1+πOverL)/priceOfHouseholds

  return outputPerCapita, XniXn

end


end
